[Feature] A2C objective class and train example #680
Conversation
vmoens
left a comment
Would you mind merging main and trying to solve the issues with the new "next" logic? Let me also know what you think of it :)
* init * strict=False * amend * amend
* Add auto-compute stats feature for ObservationNorm * Fix issue in ObservNorm init function * Quick refactor of ObservationNorm init method * Minor refactoring and adding more tests for ObservationNorm * lint * docstring * docstring Co-authored-by: vmoens <vincentmoens@gmail.com>
* init * [Feature] Nested composite spec (pytorch#654) * [Feature] Move `transform.forward` to `transform.step` (pytorch#660) * transform step function * amend * amend * amend * amend * amend * fixing key names * fixing key names * [Refactor] Transform next remove (pytorch#661) * Refactor "next_" into ("next", ) (pytorch#673) * amend * amend * bugfix * init * strict=False * strict=False * minor * amend * [BugFix] Use GitHub for flake8 pre-commit hook (pytorch#679) * amend * [BugFix] Update to strict select (pytorch#675) * init * strict=False * amend * amend * [Feature] Auto-compute stats for ObservationNorm (pytorch#669) * Add auto-compute stats feature for ObservationNorm * Fix issue in ObservNorm init function * Quick refactor of ObservationNorm init method * Minor refactoring and adding more tests for ObservationNorm * lint * docstring * docstring Co-authored-by: vmoens <vincentmoens@gmail.com> * amend * amend * lint * bf * bf * amend Co-authored-by: Romain Julien <romainjulien@fb.com> Co-authored-by: Romain Julien <romainjulien@fb.com>
Done! I brought in all the changes from main, and the training script now computes the initial stats with the key "observation_vector" instead of "next_observation_vector". The result should be the same, since it is the same tensor delayed by 1 timestep. I also checked that the example script runs without issues.
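For reference, a minimal sketch of what "the same tensor delayed by 1 timestep" means under the new ("next",) convention; the environment construction below is only illustrative, not the actual training script:

```python
# Illustrative sketch only (assumed env setup, not the PR's training script):
# under the new ("next",) convention, the entry at ("next", "observation_vector")
# of step t is the observation of step t + 1, so normalization stats computed
# from either key should be essentially the same.
from torchrl.envs.libs.gym import GymEnv
from torchrl.envs.transforms import CatTensors, TransformedEnv

env = TransformedEnv(
    GymEnv("HalfCheetah-v4"),
    CatTensors(in_keys=["observation"], out_key="observation_vector"),
)
rollout = env.rollout(100)  # single trajectory, no reset in between
assert (
    rollout["observation_vector"][1:]
    == rollout["next", "observation_vector"][:-1]
).all()
```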
It feels like you merged each diff independently, which makes a gigantic diff here (over 50 files changed).
* amend * amend * amend * amend * amend * amend
vmoens
left a comment
LGTM overall. The lint test is failing; running pre-commit should fix that. We should consider adding this to the example test pipeline (#687). After that I think we'll be good to go!
Codecov Report
|          | main   | #680   | +/-    |
|----------|--------|--------|--------|
| Coverage | 87.78% | 87.88% | +0.09% |
| Files    | 119    | 120    | +1     |
| Lines    | 20201  | 20590  | +389   |
| Hits     | 17733  | 18095  | +362   |
| Misses   | 2468   | 2495   | +27    |
vmoens
left a comment
Just a couple of minor changes and we're good to go!
Description
Added an A2C objective class.
I also created the helper functions necessary to run an A2C example, including `make_a2c_loss`, `A2CLossConfig`, `make_a2c_model`, and `A2CModelConfig`.
Creating a `make_a2c_model` helper function was not strictly necessary, since the models are the same as in PPO. However, I wanted to use fewer units in the hidden layers, so I decided to create `make_a2c_model` instead of modifying `make_ppo_model`. The two helpers can probably be merged in the future if necessary, with the network architecture passed as a parameter.
Finally, I played a bit with the parameters in the config.yaml file until I found a good enough configuration that learned pretty well in the HalfCheetah-v4 environment.
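For illustration, here is a minimal sketch of how the new pieces could be wired together in a training loop. The class and argument names used below (`A2CLoss`, `entropy_coef`, `critic_coef`, the loss output keys, and the placeholder networks/collector) are assumptions made for this sketch rather than the exact API added here:

```python
# Rough, illustrative wiring only; names below (A2CLoss, its arguments, the loss
# output keys, and the cfg-built helpers) are assumptions for this sketch and
# may differ from the actual code in this PR.
import torch
from torchrl.objectives import A2CLoss  # objective class added by this PR (import path assumed)

# actor_network / critic_network / collector stand in for what make_a2c_model
# and the data-collector helpers would return.
loss_module = A2CLoss(actor_network, critic_network, entropy_coef=0.01, critic_coef=0.5)
optimizer = torch.optim.Adam(loss_module.parameters(), lr=3e-4)

for batch in collector:          # TensorDict batches of collected experience
    losses = loss_module(batch)  # assumes advantage/value targets are handled upstream
    loss = losses["loss_objective"] + losses["loss_critic"] + losses["loss_entropy"]
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```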
Motivation and Context
There is an open issue about A2C, and while it is similar to REINFORCE and PPO, which are already in the repo, the objective is not the same. In particular, it includes an entropy bonus (which is not present in REINFORCE) and it lacks the log-probability ratio weighting, the clipping, and the KL term present in PPO.
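As a rough, illustrative sketch (not the code added in this PR), the objective described above reduces to three terms:

```python
# Illustrative-only sketch of the objective difference (not the PR's code):
# A2C keeps the plain policy-gradient term weighted by the advantage, adds an
# entropy bonus (absent in REINFORCE), and has no importance ratio, clipping,
# or KL penalty (as used in PPO).
import torch

def a2c_loss(log_prob, advantage, value, value_target,
             entropy, entropy_coef=0.01, critic_coef=0.5):
    loss_objective = -(log_prob * advantage.detach()).mean()   # policy-gradient term
    loss_entropy = -entropy_coef * entropy.mean()              # entropy bonus
    loss_critic = critic_coef * (value - value_target.detach()).pow(2).mean()  # value loss
    return loss_objective + loss_entropy + loss_critic
```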
Types of changes
What types of changes does your code introduce? Remove all that do not apply:
Checklist
Go over all the following points, and put an `x` in all the boxes that apply. If you are unsure about any of these, don't hesitate to ask. We are here to help!